{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ref: https://github.com/onnx/tutorials/blob/master/tutorials/PytorchOnnxExport.ipynb"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "tensorflow backend\n",
    "- alexnet ...X\n",
    "- vgg11 ...O\n",
    "- vgg11_bn...O\n",
    "- densenet121 ..X\n",
    "- inception_v3 ..X\n",
    "- resnet18 ..X\n",
    "- squeezenet1_0 ..X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save pre-trained model to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.autograd import Variable\n",
    "import torch.onnx\n",
    "import torchvision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard ImageNet input - 3 channels, 224x224,\n",
    "# values don't matter as we care about network structure.\n",
    "# But they can also be real inputs.\n",
    "dummy_input = Variable(torch.randn(1, 3, 224, 224))\n",
    "# Obtain your model, it can be also constructed in your script explicitly\n",
    "model = torchvision.models.vgg19(pretrained=True)\n",
    "# Invoke export\n",
    "torch.onnx.export(model, dummy_input, \"vgg19.onnx\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load model from ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "graph torch-jit-export (\n",
      "  %1[FLOAT, 1x3x224x224]\n",
      ") initializers (\n",
      "  %2[FLOAT, 64x3x3x3]\n",
      "  %3[FLOAT, 64]\n",
      "  %4[FLOAT, 64x64x3x3]\n",
      "  %5[FLOAT, 64]\n",
      "  %6[FLOAT, 128x64x3x3]\n",
      "  %7[FLOAT, 128]\n",
      "  %8[FLOAT, 128x128x3x3]\n",
      "  %9[FLOAT, 128]\n",
      "  %10[FLOAT, 256x128x3x3]\n",
      "  %11[FLOAT, 256]\n",
      "  %12[FLOAT, 256x256x3x3]\n",
      "  %13[FLOAT, 256]\n",
      "  %14[FLOAT, 256x256x3x3]\n",
      "  %15[FLOAT, 256]\n",
      "  %16[FLOAT, 256x256x3x3]\n",
      "  %17[FLOAT, 256]\n",
      "  %18[FLOAT, 512x256x3x3]\n",
      "  %19[FLOAT, 512]\n",
      "  %20[FLOAT, 512x512x3x3]\n",
      "  %21[FLOAT, 512]\n",
      "  %22[FLOAT, 512x512x3x3]\n",
      "  %23[FLOAT, 512]\n",
      "  %24[FLOAT, 512x512x3x3]\n",
      "  %25[FLOAT, 512]\n",
      "  %26[FLOAT, 512x512x3x3]\n",
      "  %27[FLOAT, 512]\n",
      "  %28[FLOAT, 512x512x3x3]\n",
      "  %29[FLOAT, 512]\n",
      "  %30[FLOAT, 512x512x3x3]\n",
      "  %31[FLOAT, 512]\n",
      "  %32[FLOAT, 512x512x3x3]\n",
      "  %33[FLOAT, 512]\n",
      "  %34[FLOAT, 4096x25088]\n",
      "  %35[FLOAT, 4096]\n",
      "  %36[FLOAT, 4096x4096]\n",
      "  %37[FLOAT, 4096]\n",
      "  %38[FLOAT, 1000x4096]\n",
      "  %39[FLOAT, 1000]\n",
      ") {\n",
      "  %41 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%1, %2)\n",
      "  %42 = Add[axis = 1, broadcast = 1](%41, %3)\n",
      "  %43 = Relu(%42)\n",
      "  %45 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%43, %4)\n",
      "  %46 = Add[axis = 1, broadcast = 1](%45, %5)\n",
      "  %47 = Relu(%46)\n",
      "  %48 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%47)\n",
      "  %50 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%48, %6)\n",
      "  %51 = Add[axis = 1, broadcast = 1](%50, %7)\n",
      "  %52 = Relu(%51)\n",
      "  %54 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%52, %8)\n",
      "  %55 = Add[axis = 1, broadcast = 1](%54, %9)\n",
      "  %56 = Relu(%55)\n",
      "  %57 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%56)\n",
      "  %59 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%57, %10)\n",
      "  %60 = Add[axis = 1, broadcast = 1](%59, %11)\n",
      "  %61 = Relu(%60)\n",
      "  %63 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%61, %12)\n",
      "  %64 = Add[axis = 1, broadcast = 1](%63, %13)\n",
      "  %65 = Relu(%64)\n",
      "  %67 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%65, %14)\n",
      "  %68 = Add[axis = 1, broadcast = 1](%67, %15)\n",
      "  %69 = Relu(%68)\n",
      "  %71 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%69, %16)\n",
      "  %72 = Add[axis = 1, broadcast = 1](%71, %17)\n",
      "  %73 = Relu(%72)\n",
      "  %74 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%73)\n",
      "  %76 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%74, %18)\n",
      "  %77 = Add[axis = 1, broadcast = 1](%76, %19)\n",
      "  %78 = Relu(%77)\n",
      "  %80 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%78, %20)\n",
      "  %81 = Add[axis = 1, broadcast = 1](%80, %21)\n",
      "  %82 = Relu(%81)\n",
      "  %84 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%82, %22)\n",
      "  %85 = Add[axis = 1, broadcast = 1](%84, %23)\n",
      "  %86 = Relu(%85)\n",
      "  %88 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%86, %24)\n",
      "  %89 = Add[axis = 1, broadcast = 1](%88, %25)\n",
      "  %90 = Relu(%89)\n",
      "  %91 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%90)\n",
      "  %93 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%91, %26)\n",
      "  %94 = Add[axis = 1, broadcast = 1](%93, %27)\n",
      "  %95 = Relu(%94)\n",
      "  %97 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%95, %28)\n",
      "  %98 = Add[axis = 1, broadcast = 1](%97, %29)\n",
      "  %99 = Relu(%98)\n",
      "  %101 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%99, %30)\n",
      "  %102 = Add[axis = 1, broadcast = 1](%101, %31)\n",
      "  %103 = Relu(%102)\n",
      "  %105 = Conv[dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [1, 1, 1, 1], strides = [1, 1]](%103, %32)\n",
      "  %106 = Add[axis = 1, broadcast = 1](%105, %33)\n",
      "  %107 = Relu(%106)\n",
      "  %108 = MaxPool[kernel_shape = [2, 2], pads = [0, 0], strides = [2, 2]](%107)\n",
      "  %109 = Reshape[shape = [1, -1]](%108)\n",
      "  %112 = Gemm[alpha = 1, beta = 1, broadcast = 1, transB = 1](%109, %34, %35)\n",
      "  %113 = Relu(%112)\n",
      "  %115, %116 = Dropout[is_test = 1, ratio = 0.5](%113)\n",
      "  %119 = Gemm[alpha = 1, beta = 1, broadcast = 1, transB = 1](%115, %36, %37)\n",
      "  %120 = Relu(%119)\n",
      "  %122, %123 = Dropout[is_test = 1, ratio = 0.5](%120)\n",
      "  %126 = Gemm[alpha = 1, beta = 1, broadcast = 1, transB = 1](%122, %38, %39)\n",
      "  return %126\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import onnx\n",
    "\n",
    "# Load the ONNX model\n",
    "model = onnx.load(\"vgg19.onnx\")\n",
    "\n",
    "# Check that the IR is well formed\n",
    "onnx.checker.check_model(model)\n",
    "\n",
    "# Print a human readable representation of the graph\n",
    "print(onnx.helper.printable_graph(model.graph))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Inference with your prefered toolkit"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 3, 224, 224)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from PIL import Image \n",
    "import numpy as np\n",
    "imgfile = './data/cat.jpeg'\n",
    "img = Image.open(imgfile).resize( (224, 224) )\n",
    "img_arr = np.array(img)\n",
    "img_arr = np.expand_dims(img_arr, -1)\n",
    "img_arr = np.transpose(img_arr, (3,2,0,1))/255.0\n",
    "img_arr.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python2.7/dist-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.\n",
      "  from ._conv import register_converters as _register_converters\n",
      "/usr/local/lib/python2.7/dist-packages/onnx_tf/backend.py:677: UserWarning: Unsupported kernel_shape attribute by Tensorflow in Conv operator. The attribute will be ignored.\n",
      "  UserWarning)\n"
     ]
    }
   ],
   "source": [
    "import onnx_tf.backend as backend\n",
    "\n",
    "rep = backend.prepare(model, device=\"CUDA:0\") # or \"CPU\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = rep.run(img_arr.astype(np.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1, 1000)\n"
     ]
    }
   ],
   "source": [
    "outputs = np.array(outputs).squeeze(0)\n",
    "print(outputs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1000,)\n"
     ]
    }
   ],
   "source": [
    "outputs = np.array(outputs).squeeze(0)\n",
    "print(outputs.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imagenet1000_clsid_to_human import cls_dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "str_ = \"\"\n",
    "for i in reversed(np.argsort(outputs)[-5:]):\n",
    "    str_ +=  \"{:.2f}% : {} \\n\".format(outputs[i], cls_dict[i])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "13.23% : Egyptian cat \n",
      "10.91% : tabby, tabby cat \n",
      "10.35% : tiger cat \n",
      "7.97% : lynx, catamount \n",
      "7.57% : vase \n",
      "\n"
     ]
    }
   ],
   "source": [
    "print (str_)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
